[WIP] Use int_scaled_matmul with int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC) #1402
Conversation
@@ -888,14 +888,19 @@ def _choose_qparams_affine(
            "preserve_zero == False is not supported for symmetric quantization"
        )
    if (
-        zero_point_domain is not None
+        zero_point_domain != ZeroPointDomain.NONE.name
+        and zero_point_domain != None
I feel we can probably remove support for None since it's the same as ZeroPointDomain.NONE.name
Thanks again for reviewing!
Some other places in the codebase are also using both ZeroPointDomain.NONE.name
and None separately:
ao/torchao/quantization/quant_primitives.py, line 543 in 31234db:
), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"
I unified these two cases into one in the latest commit, but I'm not sure whether the __tensor_unflatten__ and __tensor_flatten__ methods of some classes elsewhere in the codebase would also need changes so that they can deal with a None zero point when TorchDynamo is used. I'll run the CUDA-only UTs on my end tomorrow morning to verify.
EDIT: I haven't been able to get access to an Nvidia GPU yet.
    )
    if zero_point_domain == ZeroPointDomain.NONE.name:
Thanks for the fix; it looks like this wasn't tested before. Can you add a test for the new code path?
Also, this op is becoming too complicated; we want to split it.
Can you add a test for the new code path?
This case is tested in a UT I added in test/quantization/test_quant_primitives.py.
Also, this op is becoming too complicated; we want to split it.
Please advise if you're referring to splitting _choose_qparams_affine. If so, I could split it up into smaller methods. Thanks!
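For readers following along, a hedged sketch of what such a UT could look like; the quant_primitives functions referenced are real, but the exact signatures, the NONE-domain return convention, and the tolerance used here are my assumptions, not the PR's actual test:

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

def test_symmetric_quant_with_none_zero_point_domain():
    x = torch.randn(16, 32)
    block_size = (1, 32)  # per-row quantization
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        block_size,
        torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    # With the NONE domain, no zero point should be applied on (de)quantization.
    xq = quantize_affine(
        x, block_size, scale, zero_point, torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    xdq = dequantize_affine(
        xq, block_size, scale, zero_point, torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    # Round-trip error should stay within roughly one quantization step.
    assert torch.allclose(x, xdq, atol=2 * scale.max().item())
```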
Yeah, I meant splitting choose_qparams_affine/quantize_affine/dequantize, not into smaller methods, but into different variations, to reduce the complexity of the most common path (and remove these if/else checks). This includes removing the preserve_zero and zero_point_domain args and just having different variations of choose_qparams_affine/quantize_affine/dequantize. This should be done separately, though, since it will be a large change.
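To illustrate that direction, here is a toy per-row sketch with hypothetical names (not a proposal for the actual torchao API): each variation hard-codes one combination instead of branching on preserve_zero / zero_point_domain inside a single op.

```python
import torch

def choose_qparams_symmetric_per_row(x: torch.Tensor, target_dtype=torch.int8):
    # Symmetric variation: only a scale is ever computed; there is no zero point
    # and no preserve_zero / zero_point_domain branching.
    quant_max = torch.iinfo(target_dtype).max
    scale = x.abs().amax(dim=-1, keepdim=True) / quant_max
    return scale.clamp(min=torch.finfo(x.dtype).eps)

def quantize_symmetric_per_row(x: torch.Tensor, scale: torch.Tensor, target_dtype=torch.int8):
    # Matching quantize variation: round, clamp to the target dtype's range, cast.
    qmin, qmax = torch.iinfo(target_dtype).min, torch.iinfo(target_dtype).max
    return torch.round(x / scale).clamp(qmin, qmax).to(target_dtype)
```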
Closing in favor of #1556, which fixes the ZeroPointDomain.NONE implementation. Thanks!
Feature
Use int_scaled_matmul with asymmetrically int8-quantized activations and symmetrically int8-quantized weights by applying a compensation term for the activation's zero point.
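To make the compensation concrete, here is a minimal sketch of the idea using torch._int_mm directly; the tensor names, shapes, and fp32 accumulation are illustrative assumptions, not the PR's actual code (which routes through int_scaled_matmul):

```python
import torch

def asym_act_sym_wgt_int8_mm(x_int8, x_scale, x_zp, w_int8, w_scale):
    # x_int8:  [M, K] int8, asymmetrically quantized activation
    # x_scale: [M, 1] float, per-row activation scale
    # x_zp:    [M, 1] int32, per-row activation zero point
    # w_int8:  [K, N] int8, symmetrically quantized weight (zero point == 0)
    # w_scale: [N]    float, per-channel weight scale
    #
    # Since x ≈ x_scale * (x_int8 - x_zp) and w ≈ w_scale * w_int8:
    #   x @ w ≈ x_scale * w_scale * (x_int8 @ w_int8 - x_zp * w_int8.sum(dim=0))
    # so the matmul stays in int8/int32 and the activation zero point is handled
    # by a per-column compensation term.
    acc = torch._int_mm(x_int8, w_int8)                    # [M, N] int32 accumulator
    compensation = w_int8.sum(dim=0, dtype=torch.int32)    # [N] column sums of the weight
    acc = acc - x_zp * compensation                        # zero-point compensation
    return acc.to(torch.float32) * x_scale * w_scale       # apply both scales
```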
Motivation
Currently, optimizing GEMMs that use asymmetrically int8-quantized activations and symmetrically quantized weights with the int8_dynamic_activation_int8_weight API poses a problem: torchao does not use torch._int_mm for this case, so with frozen weights (inference, with the Inductor freezing config enabled), the frozen int8 weights are folded back into FP32 weights (as aten.mm's second argument) during Inductor's constant-folding passes.
In the sym act, sym wgt case, torch._int_mm is used, which makes it easier to leverage the frozen int8 weights with Inductor pattern-matching and to use Inductor max-autotune mode.
This PR does something similar for the asym act, sym wgt case by using torch._int_mm with int8 activations and weights, and applying a compensation term corresponding to the activation's zero points.
This change makes it possible to match this GEMM's pattern with Inductor pattern-matching and to use Inductor max-autotune to fuse the whole GEMM (I'd add support for that in a PyTorch PR).
TODO